syntax = "proto3";

package tensorflow.tpu;

import "tensorflow/compiler/xla/xla.proto";
import "tensorflow/compiler/xla/xla_data.proto";
import "tensorflow/core/framework/tensor_shape.proto";
import "tensorflow/core/framework/types.proto";
import "tensorflow/core/protobuf/tpu/dynamic_padding.proto";

option cc_enable_arenas = true;

// This is an experimental proto used in the TF/XLA bridge to store metadata
// for a compile op (e.g. _TPUCompileMlir).
// TODO(lyandy): Deprecate proto once generic metadata proto is created.
message TPUCompileMetadataProto {
  // Description of the types and shapes of the arguments to a computation.
  message Arg {
    // Classifies how an argument is supplied to the computation.
    enum Kind {
      // Zero value; this is what a reader sees when `kind` was never set.
      INVALID = 0;
      PARAMETER = 1;
      VARIABLE = 2;
      // These are args which have been guaranteed to be constants during the
      // session lifetime by the use of the GuaranteeConstOp (or ConstantOp).
      GUARANTEED_CONSTANT = 3;
    }
    // Element type of the argument.
    DataType dtype = 1;
    // Shape of the argument.
    TensorShapeProto shape = 2;
    // Kind of the argument; see Kind above.
    Kind kind = 3;

    // The cross-core sharding of this input within each replica, e.g.,
    // assigning to one core, or replicate across all cores.
    xla.OpSharding sharding = 4;

    // Whether this argument will receive the same data across all replicas.
    bool is_same_data_across_replicas = 5;

    // Tri-state policy controlling whether XLA may shard/unshard an argument.
    enum EnableXlaSharding {
      // Zero value; this is what a reader sees when `enable_xla_sharding` was
      // never set.
      DISALLOWED = 0;
      // Sharding is allowed if host training loop exists.
      TENTATIVE = 1;
      // XLA sharding of this argument is allowed.
      ALLOWED = 2;
    }
    // Whether to allow XLA to produce separate programs to shard/unshard this
    // argument. Requires this arg to be an on-device Kind::VARIABLE, or a
    // Kind::PARAMETER. For Kind::PARAMETER, it represents the initial value of
    // a variable, and retval_index_for_sharding must be specified for the
    // corresponding updated value.
    EnableXlaSharding enable_xla_sharding = 6;

    // If XLA sharding is allowed on a Kind::PARAMETER, this field is used to
    // specify the corresponding updated value in the return values. Use -1 for
    // variables that are not updated.
    // NOTE(review): presumably an index into TPUCompileMetadataProto.retvals --
    // confirm. Because an unset proto3 int32 reads as 0, writers must set -1
    // explicitly to express "not updated".
    // Field numbers 7 and 8 are declared out of numeric order here; the numbers
    // identify the fields on the wire and must not be changed to tidy this up.
    int32 retval_index_for_sharding = 8;

    // Whether this argument is placed on fast memory or not.
    bool fast_mem = 7;

    // Whether to let XLA decide the layout during compilation, as opposed to
    // using a fixed layout determined by the shape.
    bool unrestricted_layout = 9;

    // Name of the node that the arg comes from.
    string name = 10;

    // Whether to use XLA collectives to broadcast this parameter to all
    // replicas, instead of using TensorFlow Send/Recv among the tasks.
    bool requires_xla_broadcast = 11;

    // Which dimensions of this arg are bounded dynamic.
    // NOTE(review): presumably one entry per dimension of `shape`, in
    // dimension order -- confirm against the producer of this proto.
    repeated bool is_bounded_dynamic_dim = 12;
  }
  // Descriptions of the arguments to the computation, one Arg per argument.
  repeated Arg args = 1;

  // Description of the return values from a computation.
  message Retval {
    // The cross-core sharding of this return value within each replica, e.g.,
    // assigning to one core, or replicate across all cores.
    xla.OpSharding sharding = 1;
  }
  // Descriptions of the return values of the computation, one Retval per
  // return value.
  repeated Retval retvals = 2;

  // Number of replicas of the computation and number of cores in each replica.
  // TODO(b/140721404): it may not be necessary to state the number of cores per
  // replica here. Reconsider when replicated model-parallelism is implemented
  // in XLA.
  int32 num_replicas = 3;
  // Number of cores in each replica (see the TODO above).
  int32 num_cores_per_replica = 4;

  // Field numbers of removed fields; kept reserved so they are never reused.
  reserved 5;  // was device_names
  reserved 7;  // was replica_device_assignment

  // Device assignment for the computation, expressed as an
  // xla.DeviceAssignmentProto.
  // NOTE(review): presumably supersedes the removed fields reserved above --
  // confirm.
  xla.DeviceAssignmentProto device_assignment = 8;

  // A fingerprint of the function library. Ensures that any functions called
  // by the computation have matching definitions.
  uint64 function_library_fingerprint = 6;

  // Unique session identifier. Can be empty.
  string session_handle = 9;

  // Fingerprint of guaranteed_const value. The fingerprint computation inside
  // tpu_compile_op may be slow. The computation can be avoided by setting the
  // fingerprint value here.
  string guaranteed_const_fingerprint = 10;

  // Padding maps for the computation.
  // NOTE(review): tpu.PaddingMap is presumably declared in the imported
  // dynamic_padding.proto; see that file for the meaning of each entry.
  repeated tpu.PaddingMap padding_maps = 11;

  // The location of step markers that XLA compile will instrument.
  xla.DebugOptions.StepMarkerLocation step_marker_location = 12;

  // Minimum number of batches run through the XLA graph before XLA fusion
  // autotuner is enabled. Default value of zero disables the autotuner.
  // The XLA fusion autotuner can improve performance by executing a heuristic
  // search on the compiler parameters.
  int64 xla_fusion_autotuner_thresh = 13;

  // Enables TPU compiler to add partitioning policies for inputs/outputs to
  // the XLA computation for model parallelism.
  bool enable_automatic_model_parallelism = 14;

  // Whether to use XLA's SPMD or MPMD partitioner when compiler partitioning is
  // requested.
  // NOTE(review): presumably true selects SPMD and false selects MPMD --
  // confirm against the consumer.
  bool use_spmd_for_xla_partitioning = 15;

  // Whether to automatically generate XLA shardings for SPMD partitioner.
  bool use_auto_spmd_for_xla_partitioning = 18;

  // Device mesh shape used to create the sharding search space when
  // use_auto_spmd_for_xla_partitioning=true.
  repeated int64 auto_spmd_mesh_shape = 19;

  // Device mesh ids compatible with auto_spmd_mesh_shape above, used when
  // use_auto_spmd_for_xla_partitioning=true.
  repeated int64 auto_spmd_mesh_ids = 20;

  reserved 16;  // Was broadcast_replicated_parameters_via_collectives

  // A fingerprint generated by hashing the MLIR module content.
  uint64 mlir_fingerprint = 17;

  // Compilation options carried in the separately defined TPUCompileOptions
  // message.
  TPUCompileOptions compile_options = 21;

  // The name of the MLIR module.
  string module_name = 22;
}

// Stable protobuf for TPU compilation options, suitable for persistent storage.
// This proto needs to be backward compatible under maintenance.
// TODO(timshen): investigate and migrate other options from
// TPUCompileMetadataProto.
message TPUCompileOptions {
  // Precision choices for `matrix_unit_operand_precision`.
  // NOTE(review): the exact numeric behavior selected by each value is defined
  // by the consumer of this proto, not by this file -- confirm before relying
  // on it.
  enum Precision {
    // Zero value; this is what a reader sees when the field was never set.
    DEFAULT = 0;
    BFLOAT16 = 1;
    FLOAT32 = 2;
    TENSOR_FLOAT32 = 3;
  }
  // Precision setting for matrix unit operands; see Precision above.
  Precision matrix_unit_operand_precision = 1;
}
